Chapter 8: Tree-Based Methods

The Basics of Decision Trees
Hitters Dataset Example

The Hitters data is part of the ISLR package. This dataset can be extracted from the ISLR package using the following syntax.

library (ISLR)
write.csv(Hitters, "Hitters.csv")
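The exported CSV can then be loaded in Python. As a sketch (using a small hand-made sample in place of the real file, since the column values here are illustrative), the usual cleaning steps are to drop players with a missing Salary and work with log salary:

```python
import numpy as np
import pandas as pd

# Illustrative rows in the shape of the Hitters data; in practice you
# would use pd.read_csv("Hitters.csv") on the file exported from R.
df = pd.DataFrame({
    "Years": [14, 3, 11, 2],
    "Hits": [81, 130, 141, 70],
    "Salary": [475.0, np.nan, 850.0, 91.5],
})

df = df.dropna(subset=["Salary"])       # Salary has missing values in Hitters
df["LogSalary"] = np.log(df["Salary"])  # the regression tree is fit to log(Salary)
print(df)
```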

We would like to define a function that summarizes our regression tree. This can be done using the following packages.

| Package | Description |
| --- | --- |
| sklearn.tree.export_graphviz | Exports a decision tree in the DOT format. |
| StringIO | io.StringIO in Python 3. |
| pydot | Visualizes the DOT graph. |
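Put together, a minimal sketch of the export step looks like this (fit on synthetic data, with the feature names "Years" and "Hits" used purely for illustration; the pydot rendering step is shown as a comment since it requires Graphviz to be installed):

```python
from io import StringIO  # Python 3 replacement for the old StringIO module

from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor, export_graphviz

X, y = make_regression(n_samples=100, n_features=2, random_state=0)
tree = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)

# Export the fitted tree in DOT format; with out_file=None the DOT
# source is returned as a string (a StringIO buffer works as well).
dot_data = export_graphviz(tree, out_file=None,
                           feature_names=["Years", "Hits"])

# pydot (not imported here) can turn this DOT source into an image:
#   pydot.graph_from_dot_data(dot_data)[0].write_png("tree.png")
print(dot_data[:60])
```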

Regression Trees

The Hitters dataset can be used for predicting a baseball player’s Salary based on Years (the number of years that he has played in the major leagues) and Hits (the number of hits that he made in the previous year).

For the Hitters data, we fit a regression tree predicting the log salary of a baseball player from the number of years he has played in the major leagues and the number of hits he made in the previous year. At a given internal node, the label (of the form $X_j < t_k$) indicates the left-hand branch emanating from that split, and the right-hand branch corresponds to $X_j \geq t_k$. For instance, the split at the top of the tree produces two large branches: the left-hand branch corresponds to $Years < 4.5$, and the right-hand branch corresponds to $Years \geq 4.5$. The tree has two internal nodes and three terminal nodes, or leaves. The number in each leaf is the mean of the response for the observations that fall there.

Overall, the tree segments the players into three regions of predictor space: $R_1 = \{X \mid Years < 4.5\}$, $R_2 = \{X \mid Years \geq 4.5,\ Hits < 117.5\}$, and $R_3 = \{X \mid Years \geq 4.5,\ Hits \geq 117.5\}$.

Prediction via Stratification of the Feature Space

There are two steps:

  1. We divide the predictor space, that is, the set of possible values for $X_1$, $X_2$, $\ldots$ , $X_p$, into $J$ distinct and non-overlapping regions, $R_1$, $R_2$, $\ldots$ , $R_J$. This is done by finding high-dimensional rectangles (boxes) $R_1$, $R_2$, $\ldots$ , $R_J$ that minimize the RSS, given by $$\sum_{j=1}^{J}\sum_{i \in R_j} \left(y_i - \hat{y}_{R_j}\right)^2,$$ where $\hat{y}_{R_j}$ is the mean response for the training observations within the $j$th box.
  2. For every observation that falls into the region $R_j$, we make the same prediction, which is simply the mean of the response values for the training observations in $R_j$.

In order to perform recursive binary splitting, we consider all predictors $X_1$, $X_2$, $\ldots$ , $X_p$, and all possible values of the cutpoint $s$ for each of the predictors, and then choose the predictor and cutpoint such that the resulting tree has the lowest RSS.
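A single step of this search can be sketched directly (a toy illustration, not the library implementation): loop over every predictor $j$ and every candidate cutpoint $s$, and keep the pair with the smallest total RSS, where each half-plane predicts its own mean.

```python
import numpy as np

def best_split(X, y):
    """Exhaustively search predictors j and cutpoints s for the single
    binary split that minimizes the total RSS of the two half-planes."""
    best = (None, None, np.inf)  # (predictor j, cutpoint s, RSS)
    for j in range(X.shape[1]):
        for s in np.unique(X[:, j]):
            left, right = y[X[:, j] < s], y[X[:, j] >= s]
            if len(left) == 0 or len(right) == 0:
                continue  # both regions must contain observations
            # Each region predicts its mean, so its RSS is the sum of
            # squared deviations from the region mean.
            rss = ((left - left.mean()) ** 2).sum() + \
                  ((right - right.mean()) ** 2).sum()
            if rss < best[2]:
                best = (j, s, rss)
    return best

# Toy data: feature 0 cleanly separates y, so the split is found there.
X = np.array([[1.0, 10.0], [2.0, 20.0], [6.0, 30.0], [7.0, 40.0]])
y = np.array([1.0, 1.0, 5.0, 5.0])
j, s, rss = best_split(X, y)
print(j, s, rss)  # splitting on feature 0 gives RSS 0
```

Recursive binary splitting then repeats this search within each resulting region until a stopping criterion is met.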

Tree Pruning

From the textbook, we have an algorithm for building a Regression Tree, Algorithm 8.1.

In terms of Python implementations, there is an in-depth article by scikit-learn regarding Decision Trees.


  1. Use recursive binary splitting to grow a large tree on the training data, stopping only when each terminal node has fewer than some minimum number of observations.
  2. Apply cost complexity pruning to the large tree in order to obtain a sequence of best subtrees, as a function of $\alpha$:
$$\underbrace{\sum_{m=1}^{|T|} \sum_{i:\; x_i \in R_m} \left(y_i - \hat{y}_{R_m}\right)^2}_{\mbox{RSS}} + \underbrace{\alpha|T|}_{\mbox{Complexity penalty}}$$
  3. Use K-fold cross-validation to choose $\alpha$. That is, divide the training observations into $K$ folds. For each $k = 1, \ldots, K$:
    • (a) Repeat Steps 1 and 2 on all but the $k$th fold of the training data.
    • (b) Evaluate the mean squared prediction error on the data in the left-out $k$th fold, as a function of $\alpha$. Average the results for each value of $\alpha$, and pick $\alpha$ to minimize the average error.
  4. Return the subtree from Step 2 that corresponds to the chosen value of $\alpha$.
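The steps above map onto scikit-learn's cost complexity pruning support. A sketch on synthetic data (the library computes the sequence of effective $\alpha$ values for us, and cross-validation picks among them):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=200, n_features=4, noise=10.0, random_state=0)

# Step 1: grow a large tree, then obtain the sequence of effective alphas
# produced by cost complexity (weakest link) pruning.
path = DecisionTreeRegressor(random_state=0).cost_complexity_pruning_path(X, y)
alphas = path.ccp_alphas[:-1]  # drop the trivial root-only tree

# Step 3: choose alpha by K-fold cross-validation on the MSE.
cv_mse = [
    -cross_val_score(DecisionTreeRegressor(ccp_alpha=a, random_state=0),
                     X, y, cv=5, scoring="neg_mean_squared_error").mean()
    for a in alphas
]
best_alpha = alphas[int(np.argmin(cv_mse))]

# Step 4: refit the pruned subtree with the chosen alpha.
pruned = DecisionTreeRegressor(ccp_alpha=best_alpha, random_state=0).fit(X, y)
print(best_alpha, pruned.get_n_leaves())
```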

Classification Trees

A classification tree is very similar to a regression tree, except that it is used to predict a qualitative response rather than a quantitative one.

Heart Example Dataset

These data contain a binary outcome AHD for 303 patients who presented with chest pain. An outcome value of Yes indicates the presence of heart disease based on an angiographic test, while No means no heart disease. There are 13 predictors including Age, Sex, Chol (a cholesterol measurement), and other heart and lung function measurements. Cross-validation results in a tree with six terminal nodes. The dataset is available at this link.

We can use Pandas Factorize to encode categorical variables as follows,

$$\mbox{Chest Pain} = \begin{cases} 0,&\mbox{Typical},\\ 1,&\mbox{Asymptomatic},\\ 2,&\mbox{Non-Anginal},\\ 3,&\mbox{Non-Typical}. \end{cases}, \qquad \mbox{Thal} = \begin{cases} 0,&\mbox{Fixed},\\ 1,&\mbox{Normal},\\ 2,&\mbox{Reversable}. \end{cases}, \qquad \mbox{AHD} = \begin{cases} 0,&\mbox{No},\\ 1,&\mbox{Yes}. \end{cases}. $$
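The encodings above fall out of `pd.factorize`, which assigns integer codes in order of first appearance. A small sketch (the rows are illustrative stand-ins for the Heart data):

```python
import pandas as pd

# A few rows in the shape of the Heart data (values illustrative).
heart = pd.DataFrame({
    "ChestPain": ["typical", "asymptomatic", "nonanginal", "nontypical"],
    "Thal": ["fixed", "normal", "reversable", "normal"],
    "AHD": ["No", "Yes", "Yes", "No"],
})

# pd.factorize replaces each category with an integer code, in order
# of first appearance, and returns the code array plus the labels.
for col in ["ChestPain", "Thal", "AHD"]:
    heart[col], labels = pd.factorize(heart[col])
    print(col, list(labels))
```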

Next, we form the $X$ (predictor) and $y$ (response) sets and fit an sklearn DecisionTreeClassifier.
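A sketch of that fit, using a bundled binary-outcome dataset as a stand-in for the Heart data (with the Heart data, $X$ would be the 13 predictor columns and $y$ the factorized AHD column):

```python
from sklearn.datasets import load_breast_cancer  # stand-in binary dataset
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# max_leaf_nodes=6 mimics the six-terminal-node tree chosen by
# cross-validation for the Heart data.
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X, y)
print(clf.get_n_leaves(), clf.score(X, y))
```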

Trees Versus Linear Models

A linear regression model assumes a function of the form $$f(X) = \beta_0 + \sum_{j=1}^{p} X_j \beta_j,$$ whereas a regression tree assumes a piecewise-constant model of the form $$f(X) = \sum_{m=1}^{M} c_m \cdot 1_{(X \in R_m)},$$ where $R_1$, $\ldots$, $R_M$ represent a partition of the feature space. Which model performs better depends on whether the true relationship between the response and the predictors is closer to linear or to such a stratification.

Lab

Fitting Classification Trees

Carseats Dataset Example

This dataset can be extracted from the ISLR R package.

We create a variable High which takes on a value of Yes if the Sales variable exceeds 8, and takes on a value of No otherwise.
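In pandas this derived variable can be created as follows (the rows here are illustrative stand-ins for the Carseats data):

```python
import pandas as pd

# A few rows shaped like the Carseats data (values illustrative).
carseats = pd.DataFrame({"Sales": [9.5, 4.2, 11.0, 7.4]})

# High = "Yes" when Sales exceeds 8, otherwise "No".
carseats["High"] = (carseats["Sales"] > 8).map({True: "Yes", False: "No"})
print(carseats)
```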

We can encode categorical variables as follows,

$$\mbox{ShelveLoc} = \begin{cases} 0,&\mbox{Bad},\\ 1,&\mbox{Medium},\\ 2,&\mbox{Good}. \end{cases}, \qquad \mbox{Urban} = \begin{cases} 0,&\mbox{No},\\ 1,&\mbox{Yes}. \end{cases}, \qquad \mbox{US} = \begin{cases} 0,&\mbox{No},\\ 1,&\mbox{Yes}. \end{cases}, \qquad \mbox{High} = \begin{cases} 0,&\mbox{No},\\ 1,&\mbox{Yes}. \end{cases}. $$

Now we can fit the classification tree to the encoded data.

Fitting Regression Trees

Boston Dataset

This dataset can be extracted from the MASS library using the following syntax.

library (MASS)
write.csv(Boston, "Boston.csv")

This dataset was also part of the scikit-learn datasets, accessed via the following syntax.

from sklearn.datasets import load_boston

Note, however, that load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, so reading the Boston.csv file exported from R is now the more portable route.
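A sketch of fitting the regression tree, using the bundled diabetes data as a stand-in (the commented lines show how the exported Boston.csv would be used instead, where medv is the response column in the MASS data):

```python
import pandas as pd
from sklearn.datasets import load_diabetes  # bundled stand-in dataset
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# With the exported file you would instead do:
#   boston = pd.read_csv("Boston.csv")
#   X, y = boston.drop(columns=["medv"]), boston["medv"]
X, y = load_diabetes(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X_train, y_train)
print(reg.score(X_test, y_test))  # test-set R^2
```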

References

  1. James, G., Witten, D., Hastie, T., & Tibshirani, R. (2013). An Introduction to Statistical Learning (Vol. 112, pp. 3-7). New York: Springer.
  2. Jordi Warmenhoven, ISLR-python.
  3. James, G., Witten, D., Hastie, T., & Tibshirani, R. (2017). ISLR: Data for an Introduction to Statistical Learning with Applications in R.